#!/usr/bin/env python
# -*- coding: utf-8 -*-
#from __future__ import division, with_statement
'''
Copyright 2013, 陈同 (chentong_biology@163.com).  
===========================================================
'''
__author__ = 'chentong & ct586[9]'
__author_email__ = 'chentong_biology@163.com'
#=========================================================
# Usage text printed to stderr when the script is started without arguments.
desc = '''
Functional description:
    This is designed to get the average expression of genes from
    an expression matrix usually generated by DESeq2.sh.

Input file:

    1. Expr matrix output from DESeq2.sh
    
    2. Annotation file (with the first column matching the first
    column of file 1 as listed above, optional)

    3. sampleFile (header line needed)
        Samp    conditions
        SC_1    SC
        SC_2    SC
        SC_3    SC
        SC_11bA.SS_1    SC_11bA.SS
        SC_11bA.SS_2    SC_11bA.SS
        SC_can.SS_1     SC_can.SS
        SC_can.SS_2     SC_can.SS
        SG_1    SG
        SG_2    SG
        SG_3    SG
        SG_bam_1        SG_bam
        SG_bam_2        SG_bam


'''

import sys
import os
from json import dumps as json_dumps
from time import localtime, strftime 
timeformat = "%Y-%m-%d %H:%M:%S"
from optparse import OptionParser as OP
#from multiprocessing.dummy import Pool as ThreadPool
debug = 0
from math import log as ln

def fprint(content):
    '''
    Print `content` to stdout as JSON text indented by one space.
    '''
    rendered = json_dumps(content, indent=1)
    print(rendered)

def cmdparameter(argv):
    '''
    Parse the command line and return (options, args).

    argv: the full argument vector including the program name
          (main() passes sys.argv).

    When only the program name is given, the module description is
    written to stderr, the option help is printed and the process exits
    with status 1.  A missing -e or -s ends the run through
    parser.error (usage message on stderr, exit status 2).

    Note: -v and -d are declared without an action, so optparse expects
    a value after them (e.g. "-v 1").
    '''
    usages = '''%prog -e expr_matrix -s sampleFile -a anno >output

%prog -e expr_matrix -s sampleFile >output
'''
    parser = OP(usage=usages)
    parser.add_option("-e", "--expr-matrix", dest="expr",
        metavar="EXPR-MATRIX", help="The output of DESeq2.sh.")
    parser.add_option("-s", "--sampleFile", dest="samp",
        metavar="sampleFile", help="The sampleFile given to DESeq2.sh.")
    parser.add_option("-a", "--anno", dest="anno",
        help="Annotation file "
             "(Trinotate_annotation_report.simplify.xls) [optional]")
    parser.add_option("-c", "--anno-key-column", dest="key_c",
        default=1, help="Specify the column containing IDs. "
                        "Default 1 meaning the first column (gene name). "
                        "Accept 2 indicating the second column "
                        "(transcript name).")
    parser.add_option("-v", "--verbose", dest="verbose",
        default=0, help="Show process information")
    parser.add_option("-d", "--debug", dest="debug",
        default=False, help="Debug the program")
    if len(argv) == 1:
        print(desc, file=sys.stderr)
        # Print the help in-process instead of re-running the script
        # through a shell: that broke on paths containing spaces and
        # could start a different interpreter than the current one.
        parser.print_help()
        sys.exit(1)
    (options, args) = parser.parse_args(argv[1:])
    if options.expr is None:
        parser.error("A filename needed for -e")
    # main() always reads the sample sheet, so it is mandatory as well.
    if options.samp is None:
        parser.error("A filename needed for -s")
    return (options, args)
#--------------------------------------------------------------------
def readSamp(samp):
    '''
    Read the sample sheet and map each sample name to its condition.

    samp: path of a whitespace-separated text file.  The first line is
          a header and is skipped; on every other line the first field
          is the sample name and the second its condition.  Blank lines
          are ignored and fields after the second are not used.

    Returns sampD = {'SC_1': 'SC', 'SC_2': 'SC', ...}

    Raises ValueError for a line with a single field and AssertionError
    when a sample name occurs twice.
    '''
    sampD = {}
    with open(samp) as samp_fh:
        next(samp_fh, None)  # header line
        for line in samp_fh:
            fieldL = line.split()
            if not fieldL:
                # e.g. trailing empty lines at the end of the sheet
                continue
            if len(fieldL) < 2:
                raise ValueError(
                    "Sample line needs <sample> <condition>: %r" % line)
            key, value = fieldL[0], fieldL[1]
            assert key not in sampD, "Duplicate keys %s" % key
            sampD[key] = value
    return sampD
#----------------------------------

def readMatrix(expr):
    '''
    Load a tab-separated expression matrix into a dictionary.

    expr: path of the matrix; the first line is the header.

    Returns
    matrixD = {'head': [column names without the first one],
               'gene': {'samp': expr, 'samp2': expr2}}
    Cell values are kept as the strings read from the file.  The key
    'head' is reserved for the header, so a row with that ID is
    reported as a duplicate.

    Raises AssertionError for a repeated row ID and IndexError when a
    row has more cells than the header.

    When the module-level `debug` flag is true, the first 9 data rows
    are echoed to stderr.
    '''
    matrixD = {}
    headerL = None
    shown = 0
    with open(expr) as expr_fh:
        for line in expr_fh:
            lineL = line.strip().split('\t')
            if headerL is None:
                headerL = lineL
                matrixD['head'] = headerL[1:]
                continue
            #-----------------------
            key = lineL[0]
            assert key not in matrixD, "Duplicate %s" % key
            rowD = {}
            for i, cell in enumerate(lineL[1:], 1):
                rowD[headerL[i]] = cell
            matrixD[key] = rowD
            if debug and shown < 9:
                print("Key\tvalue", key, lineL, file=sys.stderr)
                shown += 1
    #---------------------------------
    return matrixD
#------------------------------------------

def readAnno(anno, key_c=0):
    '''
    Read the optional annotation table.

    anno:  path of a tab-separated file whose first line is a header;
           a false value (None, '') means "no annotation".
    key_c: 0-based index of the column holding the IDs.

    Returns annoD = {'head': header line, 'ID': whole stripped line}.
    When an ID occurs on several lines the longest line is kept.
    Returns {} when no file is given.  Blank lines are skipped.

    Raises IndexError when a non-blank data line has no column key_c.
    '''
    annoD = {}
    if not anno:
        return annoD
    with open(anno) as anno_fh:
        header = True
        for line in anno_fh:
            record = line.strip()
            if header:
                key = 'head'
                header = False
            else:
                if not record:
                    continue
                # When the key is the last column it still carries the
                # line terminator; drop it so lookups by ID can match.
                key = line.split('\t', key_c + 1)[key_c].rstrip('\r\n')
            if key not in annoD or len(record) > len(annoD[key]):
                annoD[key] = record
    return annoD
#-------------------------------------

#def output(aDict, matrixD, prefix, annoD):
#    '''
#    aDict = {'compare': {
#                'samp':[]
#                'id':[]}   
#            }
#    matrixD = {'head':[],
#             'gene':{'samp':expr, 'samp2':expr2}}
#    '''
#    for key, valueD in aDict.items():
#        output = prefix+key+".results"
#        if annoD:
#            output_anno = prefix+key+".anno.xls"
#            anno_fh = open(output_anno, 'w')
#        fh = open(output, 'w')
#        sampleL = valueD['samp']
#        idL = valueD['id']
#        headerL = matrixD['head']
#        existL = [samp for samp in headerL if samp in sampleL]
#        print >>fh, "%s\t%s" % ('gene', '\t'.join(existL))
#        #print >>sys.stderr, matrixD['CG11790']
#        if annoD:
#            print >>anno_fh, "%s\t%s\t%s" % \
#                ('gene', '\t'.join(existL), annoD['head'])
#        for id in idL:
#            print >>fh, "%s\t%s" % (id, '\t'.join(\
#                [matrixD[id][samp] for samp in existL]))
#            if annoD:
#                print >>anno_fh, "%s\t%s\t%s" % (id, '\t'.join(\
#                    [matrixD[id][samp] for samp in existL]),
#                    annoD.get(id, ""))
#            #------END one id of each file------------
#        #----------END each item-----------------------
#        fh.close()
#        anno_fh.close()
##---------------------------------------------

def computeShannon(aList, plus=1):
    '''
    Return the Shannon entropy (base 2) of an expression profile.

    aList: expression values (numbers or numeric strings).
    plus:  pseudo-count added to every value before the values are
           turned into proportions (default 1).

    Each shifted value is divided by the total of all shifted values
    and the result is -sum(p * log2(p)); zero proportions contribute 0.

    A hint is written to stderr when fewer than two values are given.
    Raises AssertionError when the shifted values sum to zero.
    '''
    #print aList
    if len(aList) < 2:
        print("You may need to specify -I parameter if the program stops.",
              file=sys.stderr)
    # Honour the caller's pseudo-count instead of a hard-coded +1.
    expr = [float(i) + plus for i in aList]
    expr_sum = sum(expr)
    assert expr_sum != 0
    expr_R = [1.0 * i / expr_sum for i in expr]
    expr_Log = [i * ln(i) / ln(2) if i != 0 else i for i in expr_R]
    shannon = -1 * sum(expr_Log)
    return shannon
#-------------------------------

#def sortExprMatrix(exprMatrix):
#    '''
#    exprMatrix = [['id', expr, anno], 
#                  ['id', expr, anno]]
#    '''

#----END sortExprMatrix-------------------------

def output(aDict, matrixD, prefix, annoD):
    '''
    Write, for every entry of aDict, the expression of the listed IDs
    in the listed samples together with their Shannon entropy.

    aDict = {'compare': {
                'samp':[]
                'id':[]}
            }
    matrixD = {'head':[],
             'gene':{'samp':expr, 'samp2':expr2}}
    prefix: text prepended to each output file name.
    annoD = {'head': header text, 'gene': annotation text}; may be
            empty.

    Files written per key of aDict:
      <prefix><key>.results   gene, expression columns, shannonEntropy
      <prefix><key>.anno.xls  the same plus the annotation text, only
                              when annoD is not empty
    Sample columns follow the order of matrixD['head']; rows are sorted
    by increasing entropy (numeric value of the 3-decimal figure that
    is written out).
    '''
    for key, valueD in aDict.items():
        wantedS = set(valueD['samp'])
        existL = [samp for samp in matrixD['head'] if samp in wantedS]
        rowL = []
        for gene in valueD['id']:
            exprL = [matrixD[gene][samp] for samp in existL]
            shannon = "%.3f" % computeShannon(exprL)
            rowL.append([gene, '\t'.join(exprL), shannon])
        # Numeric ordering; comparing the formatted text would put
        # "10.000" before "2.000".
        rowL.sort(key=lambda row: float(row[2]))
        header = "%s\t%s\tshannonEntropy" % ('gene', '\t'.join(existL))
        with open(prefix + key + ".results", 'w') as fh:
            print(header, file=fh)
            for row in rowL:
                # Always gene, expression and entropy, with or without
                # an annotation table.
                print('\t'.join(row), file=fh)
        if annoD:
            with open(prefix + key + ".anno.xls", 'w') as anno_fh:
                print("%s\t%s" % (header, annoD['head']), file=anno_fh)
                for row in rowL:
                    print("%s\t%s" % ('\t'.join(row),
                                      annoD.get(row[0], "")), file=anno_fh)
#---------------------------------------------
#---------------------------------------------


def main():
    '''
    Print the per-condition mean expression of every row of the
    expression matrix to stdout (tab-separated).

    Steps:
      1. parse sys.argv with cmdparameter();
      2. read the sample sheet (sample -> condition) and the optional
         annotation table;
      3. print a header: 'gene', the conditions in order of first
         appearance among the matrix columns and, when an annotation
         is given, its header line;
      4. for each data row print the ID, the mean of the columns of
         each condition and, when an annotation is given, the
         annotation text of that ID ('' if unknown).

    A matrix column that is not listed in the sample sheet is treated
    as a condition of its own, named after the column.  Blank lines in
    the matrix are skipped.  The module-level `debug` flag is set from
    the -d option.
    '''
    options, args = cmdparameter(sys.argv)
    #-----------------------------------
    sampD = readSamp(options.samp)
    # sampD = {'T0_1':'T0', 'T0_2':'T0', 'T1_1':'T1', 'T1_2':'T1' }
    key_c = int(options.key_c) - 1
    annoD = readAnno(options.anno, key_c)

    verbose = options.verbose
    global debug
    debug = options.debug
    #-----------------------------------
    headerL = None
    with open(options.expr) as expr_fh:
        for line in expr_fh:
            if headerL is None:
                headerL = line.rstrip().split("\t")
                headerL[0] = 'gene'
                # Condition of every column, resolved once.  The same
                # fallback (column name) is used for header and data so
                # an unlisted column no longer raises KeyError.
                groupL = [sampD.get(name, name) for name in headerL]
                condL = []
                for group in groupL[1:]:
                    if group not in condL:
                        condL.append(group)
                head_text = '\t'.join(['gene'] + condL)
                if annoD:
                    print("%s\t%s" % (head_text, annoD['head']))
                else:
                    print(head_text)
                continue
            lineL = line.strip().split('\t')
            if lineL == ['']:
                continue
            key = lineL[0]
            exprD = {}
            for i in range(1, len(headerL)):
                exprD.setdefault(groupL[i], []).append(float(lineL[i]))
            #--------------------------------------------
            exprL = [str(sum(exprD[cond]) / len(exprD[cond]))
                     for cond in condL]
            if annoD:
                print("%s\t%s\t%s" % (key, '\t'.join(exprL),
                        annoD.get(key, '')))
            else:
                print("%s\t%s" % (key, '\t'.join(exprL)))
    #-------------END reading file----------
    if verbose:
        print("--Successful %s" % strftime(timeformat, localtime()), file=sys.stderr)

if __name__ == '__main__':
    startTime = strftime(timeformat, localtime())
    main()
    endTime = strftime(timeformat, localtime())
    # Append one record per run to python.log in the working directory:
    # the command line followed by its start and end time.
    logEntry = "%s\n\tRun time : %s - %s " % (
        ' '.join(sys.argv), startTime, endTime)
    with open('python.log', 'a') as fh:
        fh.write(logEntry + "\n")
    ###---------profile the program---------
    #import profile
    #profile_output = sys.argv[0]+".prof.txt")
    #profile.run("main()", profile_output)
    #import pstats
    #p = pstats.Stats(profile_output)
    #p.sort_stats("time").print_stats()
    ###---------profile the program---------


